{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2To_0Q86toj4"
      },
      "source": [
        "# Optimistic Gradient Descent in a Bilinear Min-Max Problem\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/ogda_example.ipynb)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MMvXmgsvTmcl"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VqpOVQIgTuJ0"
      },
      "source": [
        "Consider the following min-max problem:\n",
        "\n",
        "$$\n",
        "\\min_{x \\in \\mathbb R^m} \\max_{y\\in\\mathbb R^n} f(x,y),\n",
        "$$\n",
        "\n",
        "where $f: \\mathbb R^m \\times \\mathbb R^n \\to \\mathbb R$ is a convex-concave function. The solution to such a problem is a saddle-point $(x^\\star, y^\\star)\\in \\mathbb R^m \\times \\mathbb R^n$ such that\n",
        "\n",
        "$$\n",
        "f(x^\\star, y) \\leq f(x^\\star, y^\\star) \\leq f(x, y^\\star).\n",
        "$$\n",
        "\n",
        "Standard gradient descent-ascent (GDA) updates $x$ and $y$ according to the following update rule at step $k$: \n",
        "\n",
        "$$\n",
        "x_{k+1} = x_k - \\eta_k \\nabla_x f(x_k, y_k) \\\\\n",
        "y_{k+1} = y_k + \\eta_k \\nabla_y f(x_k, y_k),\n",
        "$$\n",
        "\n",
        "where $\\eta_k$ is a step size. However, it's well-documented that GDA can fail to converge in this setting. This is an important issue because gradient-based min-max optimisation is increasingly prevalent in machine learning (e.g., GANs, constrained RL). *Optimistic* GDA (OGDA) addresses this shortcoming by introducing a form of memory-based negative momentum:  \n",
        "\n",
        "$$\n",
        "x_{k+1} = x_k - 2 \\eta_k \\nabla_x f(x_k, y_k) + \\eta_k  \\nabla_x f(x_{k-1}, y_{k-1})  \\\\\n",
        "y_{k+1} = y_k + 2 \\eta_k \\nabla_y f(x_k, y_k) - \\eta_k \\nabla_y f(x_{k-1}, y_{k-1})).\n",
        "$$\n",
        "\n",
        "Thus, to implement OGD (or OGA), the optimiser needs to keep track of the gradient from the previous step. OGDA has been formally shown to converge to the optimum $(x_k, y_k) \\to (x^\\star, y^\\star)$ in this setting. The generalised form of the OGDA update rule is given by\n",
        "\n",
        "$$\n",
        "x_{k+1} = x_k - (\\alpha + \\beta) \\eta_k \\nabla_x f(x_k, y_k) + \\beta \\eta_k \\nabla_x f(x_{k-1}, y_{k-1})  \\\\\n",
        "y_{k+1} = y_k + (\\alpha + \\beta) \\eta_k \\nabla_y f(x_k, y_k) - \\beta \\eta_k \\nabla_y f(x_{k-1}, y_{k-1})),\n",
        "$$\n",
        "\n",
        "which recovers standard OGDA when $\\alpha=\\beta=1$. See [Mokhtari et al., 2019](https://arxiv.org/abs/1901.08511v2) for more details. "
      ]
    },
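    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the roles of $\\alpha$ and $\\beta$, regroup the generalised $x$-update, using $g_k = \\nabla_x f(x_k, y_k)$ as shorthand introduced here:\n",
        "\n",
        "$$\n",
        "x_{k+1} = x_k - (\\alpha + \\beta) \\eta_k g_k + \\beta \\eta_k g_{k-1} = x_k - \\alpha \\eta_k g_k - \\beta \\eta_k (g_k - g_{k-1}).\n",
        "$$\n",
        "\n",
        "The $\\alpha$ term is an ordinary gradient step, while the $\\beta$ term moves against the most recent change in the gradient. With $\\alpha = \\beta = 1$ the step is $-\\eta_k \\left(g_k + (g_k - g_{k-1})\\right)$, which can be read as a step along a linear extrapolation of the next gradient from the last two. The $y$-update decomposes in the same way with the signs flipped."
      ]
    },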
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nSyJyTSXszQ0"
      },
      "source": [
        "Define a bilinear min-max objective function: $\\min_x \\max_y xy$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "snDy575-iDXw"
      },
      "outputs": [],
      "source": [
        "def f(params: jnp.ndarray) -> jnp.ndarray:\n",
        "  \"\"\"Objective: min_x max_y xy.\"\"\"\n",
        "  return params[\"x\"] * params[\"y\"]"
      ]
    },
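    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For this objective the gradients are $\\nabla_x f(x, y) = y$ and $\\nabla_y f(x, y) = x$, so one GDA step with a constant step size $\\eta$ is the linear map\n",
        "\n",
        "$$\n",
        "\\begin{pmatrix} x_{k+1} \\\\ y_{k+1} \\end{pmatrix} = \\begin{pmatrix} 1 & -\\eta \\\\ \\eta & 1 \\end{pmatrix} \\begin{pmatrix} x_k \\\\ y_k \\end{pmatrix}.\n",
        "$$\n",
        "\n",
        "Expanding the squares, the cross terms $-2 \\eta x_k y_k$ and $+2 \\eta x_k y_k$ cancel:\n",
        "\n",
        "$$\n",
        "x_{k+1}^2 + y_{k+1}^2 = (x_k - \\eta y_k)^2 + (y_k + \\eta x_k)^2 = (1 + \\eta^2)(x_k^2 + y_k^2).\n",
        "$$\n",
        "\n",
        "The distance to the saddle point $(0, 0)$ therefore grows by a factor of $\\sqrt{1 + \\eta^2} > 1$ at every step, for any $\\eta > 0$. With $\\eta = 0.1$ and the $1000$ steps used below, it grows by $(1.01)^{500} \\approx 145$ overall, so GDA spirals outwards rather than converging."
      ]
    },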
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G-4JMKlgs-Lr"
      },
      "source": [
        "Define an optimisation loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MXXxtGs0qlfy"
      },
      "outputs": [],
      "source": [
        "def optimise(params: optax.Params, x_optimiser: optax.GradientTransformation, y_optimiser: optax.GradientTransformation, n_steps: int = 1000, display_every: int = 100) -> optax.Params:\n",
        "  \"\"\"An optimisation loop minimising x and maximising y.\"\"\"\n",
        "\n",
        "  x_opt_state = x_optimiser.init(params[\"x\"])\n",
        "  y_opt_state = y_optimiser.init(params[\"y\"])\n",
        "  param_hist = [params]\n",
        "  f_hist = []\n",
        "\n",
        "  @jax.jit\n",
        "  def step(params, x_opt_state, y_opt_state):\n",
        "    f_value, grads = jax.value_and_grad(f)(params)\n",
        "    x_update, x_opt_state = x_optimiser.update(grads[\"x\"], x_opt_state, params[\"x\"])\n",
        "    # note that we\"re maximising y so we feed in the negative gradient to the OGD update\n",
        "    y_update, y_opt_state = y_optimiser.update(-grads[\"y\"], y_opt_state, params[\"y\"])\n",
        "    updates = {\"x\": x_update, \"y\": y_update}\n",
        "    params = optax.apply_updates(params, updates)\n",
        "    return params, x_opt_state, y_opt_state, f_value\n",
        "\n",
        "  for k in range(n_steps):\n",
        "    params, x_opt_state, y_opt_state, f_value = step(params, x_opt_state, y_opt_state)\n",
        "    param_hist.append(params)\n",
        "    f_hist.append(f_value)\n",
        "    if k % display_every == 0:\n",
        "      print(f\"step {k}, f(x, y) = {f_value}, (x, y) = ({params['x']}, {params['y']})\")\n",
        "\n",
        "  return param_hist, f_hist"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gDtB7gJdtPZj"
      },
      "source": [
        "Initialise $x$ and $y$, as well as optimisers for each. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JvlhEUMat1PX"
      },
      "outputs": [],
      "source": [
        "initial_params = {\n",
        "    \"x\": jnp.array(1.0),\n",
        "    \"y\": jnp.array(1.0)\n",
        "}\n",
        "\n",
        "# GDA\n",
        "x_gd_optimiser = optax.sgd(learning_rate=0.1)\n",
        "y_ga_optimiser = optax.sgd(learning_rate=0.1)\n",
        "\n",
        "# OGDA\n",
        "x_ogd_optimiser = optax.optimistic_gradient_descent(learning_rate=0.1)\n",
        "y_oga_optimiser = optax.optimistic_gradient_descent(learning_rate=0.1)"
      ]
    },
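    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a cross-check on what these two optimisers compute for $f(x, y) = xy$, the sketch below writes the generalised update out by hand in plain Python, starting from $(1, 1)$ with $\\eta = 0.1$. The helper `run_bilinear` is hypothetical (it is not part of optax), and starting the previous gradient at zero before the first step is an assumption of this sketch. With $\\alpha = 1, \\beta = 0$ it reduces to GDA, and with $\\alpha = \\beta = 1$ it is standard OGDA; plain floats suffice because the problem has only two scalar variables."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def run_bilinear(eta: float, alpha: float, beta: float, n_steps: int) -> float:\n",
        "  \"\"\"Hypothetical helper: generalised OGDA on f(x, y) = x * y, written out by hand.\n",
        "\n",
        "  alpha=1, beta=0 gives plain GDA; alpha=beta=1 gives standard OGDA.\n",
        "  Returns the final distance to the saddle point (0, 0).\n",
        "  \"\"\"\n",
        "  x, y = 1.0, 1.0\n",
        "  # Assumption: the previous gradients are zero before the first step.\n",
        "  prev_gx, prev_gy = 0.0, 0.0\n",
        "  for _ in range(n_steps):\n",
        "    gx, gy = y, x  # grad_x f = y, grad_y f = x\n",
        "    # Simultaneous updates: descent in x, ascent in y.\n",
        "    x_new = x - eta * ((alpha + beta) * gx - beta * prev_gx)\n",
        "    y_new = y + eta * ((alpha + beta) * gy - beta * prev_gy)\n",
        "    prev_gx, prev_gy = gx, gy\n",
        "    x, y = x_new, y_new\n",
        "  return math.hypot(x, y)\n",
        "\n",
        "\n",
        "gda_dist = run_bilinear(eta=0.1, alpha=1.0, beta=0.0, n_steps=1000)\n",
        "ogda_dist = run_bilinear(eta=0.1, alpha=1.0, beta=1.0, n_steps=1000)\n",
        "print(f\"GDA distance to (0, 0) after 1000 steps:  {gda_dist:.3e}\")\n",
        "print(f\"OGDA distance to (0, 0) after 1000 steps: {ogda_dist:.3e}\")"
      ]
    },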
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF8oQEjLRO3a"
      },
      "source": [
        "Run each method."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E-0_CtpjuEGi"
      },
      "outputs": [],
      "source": [
        "gda_hist, gda_f_hist = optimise(initial_params, x_gd_optimiser, y_ga_optimiser)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wBeMQEILwKsJ"
      },
      "outputs": [],
      "source": [
        "ogda_hist, ogda_f_hist = optimise(initial_params, x_ogd_optimiser, y_oga_optimiser)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S504XrZrtXNe"
      },
      "source": [
        "Visualise the optimisation trajectories. The optimal solution is $(0, 0)$. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S-XvwF9HujRT"
      },
      "outputs": [],
      "source": [
        "gda_xs, gda_ys = [p[\"x\"] for p in gda_hist], [p[\"y\"] for p in gda_hist]\n",
        "ogda_xs, ogda_ys = [p[\"x\"] for p in ogda_hist], [p[\"y\"] for p in ogda_hist]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WDgiytoDvX8N"
      },
      "outputs": [],
      "source": [
        "plt.plot(gda_xs, gda_ys, alpha=0.6, color=\"C0\", label=\"GDA\")\n",
        "plt.plot(ogda_xs, ogda_ys, alpha=0.6, color=\"C1\", label=\"OGDA\")\n",
        "plt.scatter([1], [1], color=\"r\", label=r\"$(x_0, y_0)$\", s=30)\n",
        "plt.scatter([0], [0], color=\"k\", label=r\"$(x^\\star, y^\\star)$\", s=30)\n",
        "plt.xlim([-2.0, 2.0])\n",
        "plt.ylim([-2.0, 2.0])\n",
        "plt.legend(loc=\"lower right\")\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "",
        "kind": "private"
      },
      "name": "ogda_example.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "15hB4sFTdcHM7tf7l03PHZdwaP7AfdI8K",
          "timestamp": 1658099029490
        },
        {
          "file_id": "1Orjeh6PdEz2Vuj_XGGAStvkyck_zzT72",
          "timestamp": 1658096315530
        },
        {
          "file_id": "1v26CV4ivf38ZnYyd18PDnRFX4WMi-MX9",
          "timestamp": 1657803636509
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
